from typing import Sequence

import torch
import torchvision.transforms as T


def normalization(
    x: torch.Tensor,
    *,
    mean: Sequence[float] = (0.485, 0.456, 0.406),
    std: Sequence[float] = (0.229, 0.224, 0.225),
) -> torch.Tensor:
    """Standardize images channel-wise: ``(x - mean) / std``.

    The defaults are the ImageNet per-channel statistics and reproduce the
    former ``torchvision.transforms.Normalize`` call exactly (same ops, same
    dtype casting of the statistics).

    :param x: floating-point images of shape (B, C, H, W) or (C, H, W); the
        channel axis is the third-from-last dimension.
    :param mean: per-channel means, one value per channel.
    :param std: per-channel standard deviations, one value per channel.
    :return: a new tensor (``x`` is not modified) with the same shape, dtype
        and device as ``x``.
    :raises TypeError: if ``x`` is not a floating-point tensor.
    :raises ValueError: if ``x`` has fewer than 3 dimensions, or ``std``
        contains a zero.
    """
    if not isinstance(x, torch.Tensor) or not x.is_floating_point():
        raise TypeError("x must be a floating-point torch.Tensor")
    if x.ndim < 3:
        raise ValueError(f"expected images of shape (..., C, H, W), got {tuple(x.shape)}")
    # Cast the statistics to x's dtype/device and shape them (C, 1, 1) so they
    # broadcast over H and W (and over the batch axis, when present).
    mean_t = torch.as_tensor(mean, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    if (std_t == 0).any():
        raise ValueError(f"std evaluated to zero in {x.dtype}, leading to division by zero")
    return (x - mean_t) / std_t


def de_normalization(
    x: torch.Tensor,
    *,
    mean: Sequence[float] = (0.485, 0.456, 0.406),
    std: Sequence[float] = (0.229, 0.224, 0.225),
) -> torch.Tensor:
    """Undo channel-wise standardization: ``x * std + mean``, clamped to [0, 1].

    The inverse is computed directly from the same statistics used for the
    forward standardization. Hand-rounded reciprocal constants such as
    ``-mean/std`` and ``1/std`` introduce a round-trip error of ~1e-4.

    :param x: standardized floating-point images of shape (B, C, H, W) or
        (C, H, W); the channel axis is the third-from-last dimension.
    :param mean: per-channel means that were subtracted, one per channel.
    :param std: per-channel standard deviations that were divided by, one
        per channel.
    :return: a new tensor with the same shape, dtype and device as ``x``,
        with values clamped to [0.0, 1.0].
    :raises TypeError: if ``x`` is not a floating-point tensor.
    :raises ValueError: if ``x`` has fewer than 3 dimensions.
    """
    if not isinstance(x, torch.Tensor) or not x.is_floating_point():
        raise TypeError("x must be a floating-point torch.Tensor")
    if x.ndim < 3:
        raise ValueError(f"expected images of shape (..., C, H, W), got {tuple(x.shape)}")
    # Cast the statistics to x's dtype/device (so integer truncation or dtype
    # promotion cannot occur) and shape them (C, 1, 1) for broadcasting.
    mean_t = torch.as_tensor(mean, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    # Clamp because downstream consumers expect a valid [0, 1] image even if
    # the standardized input drifted (e.g. after model edits).
    return torch.clamp(x * std_t + mean_t, min=0.0, max=1.0)
